{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "tensorflow_hub_import.ipynb",
      "provenance": [],
      "collapsed_sections": [
        "-V0X0E7LkEa4",
        "FH3IRpYTta2v"
      ]
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FH3IRpYTta2v"
      },
      "source": [
        "##### Copyright 2021 The IREE Authors"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mWGa71_Ct2ug",
        "cellView": "form"
      },
      "source": [
        "#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n",
        "# See https://llvm.org/LICENSE.txt for license information.\n",
        "# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception"
      ],
      "execution_count": 10,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Qb3S0mSjpK7J"
      },
      "source": [
        "# IREE TensorFlow Hub Import\n",
        "\n",
        "This notebook demonstrates how to download, import, and compile models from [TensorFlow Hub](https://tfhub.dev/). It covers:\n",
        "\n",
        "* Downloading a model from TensorFlow Hub\n",
        "* Ensuring the model has serving signatures needed for import\n",
        "* Importing and compiling the model with IREE\n",
        "\n",
        "At the end of the notebook, the compilation artifacts are compressed into a .zip file for you to download and use in an application.\n",
        "\n",
        "See also https://iree.dev/guides/ml-frameworks/tensorflow/."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9rNAJKNVkKOr"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RdVc4TbOkHM2"
      },
      "source": [
        "%%capture\n",
        "!python -m pip install --pre iree-base-compiler iree-base-runtime iree-tools-tf -f https://iree.dev/pip-release-links.html"
      ],
      "execution_count": 11,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qRwv3qI_l5O_",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "8d3bf1f1-1843-4fe9-80e0-a9fc5b194778"
      },
      "source": [
        "import os\n",
        "import tensorflow as tf\n",
        "import tensorflow_hub as hub\n",
        "import tempfile\n",
        "from IPython.display import clear_output\n",
        "\n",
        "from iree.compiler import tf as tfc\n",
        "\n",
        "# Print version information for future notebook users to reference.\n",
        "print(\"TensorFlow version: \", tf.__version__)\n",
        "\n",
        "ARTIFACTS_DIR = os.path.join(tempfile.gettempdir(), \"iree\", \"colab_artifacts\")\n",
        "os.makedirs(ARTIFACTS_DIR, exist_ok=True)\n",
        "print(f\"Using artifacts directory '{ARTIFACTS_DIR}'\")"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "TensorFlow version:  2.12.0\n",
            "Using artifacts directory '/tmp/iree/colab_artifacts'\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZZAobcAhocFE"
      },
      "source": [
        "## Import pretrained [`mobilenet_v2`](https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4) model\n",
        "\n",
        "IREE supports importing TensorFlow 2 models exported in the [SavedModel](https://www.tensorflow.org/guide/saved_model) format. This model we'll be importing is published in that format already, while other models may need to be converted first.\n",
        "\n",
        "MobileNet V2 is a family of neural network architectures for efficient on-device image classification and related tasks. This TensorFlow Hub module contains a trained instance of one particular network architecture packaged to perform image classification."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "7fd0vmnloZo9",
        "outputId": "dabea3a2-d312-4729-c947-b24216a6c25b"
      },
      "source": [
        "#@title Download the pretrained model\n",
        "\n",
        "# Use the `hub` library to download the pretrained model to the local disk\n",
        "# https://www.tensorflow.org/hub/api_docs/python/hub\n",
        "HUB_PATH = \"https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4\"\n",
        "model_path = hub.resolve(HUB_PATH)\n",
        "print(f\"Downloaded model from tfhub to path: '{model_path}'\")"
      ],
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Downloaded model from tfhub to path: '/tmp/tfhub_modules/426589ad685896ab7954855255a52db3442cb38d'\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CedNRSQTOE7C"
      },
      "source": [
        "### Check for serving signatures and re-export as needed\n",
        "\n",
        "IREE's compiler tools, like TensorFlow's `saved_model_cli` and other tools, require \"serving signatures\" to be defined in SavedModels.\n",
        "\n",
        "More references:\n",
        "\n",
        "* https://www.tensorflow.org/tfx/serving/signature_defs\n",
        "* https://blog.tensorflow.org/2021/03/a-tour-of-savedmodel-signatures.html"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "qiO66oEYQmsd",
        "outputId": "91f724db-01cd-4dd3-c55c-ba4431233cfa"
      },
      "source": [
        "#@title Check for serving signatures\n",
        "\n",
        "# Load the SavedModel from the local disk and check if it has serving signatures\n",
        "# https://www.tensorflow.org/guide/saved_model#loading_and_using_a_custom_model\n",
        "loaded_model = tf.saved_model.load(model_path)\n",
        "serving_signatures = list(loaded_model.signatures.keys())\n",
        "print(f\"Loaded SavedModel from '{model_path}'\")\n",
        "print(f\"Serving signatures: {serving_signatures}\")\n",
        "\n",
        "# Also check with the saved_model_cli:\n",
        "print(\"\\n---\\n\")\n",
        "print(\"Checking for signature_defs using saved_model_cli:\\n\")\n",
        "!saved_model_cli show --dir {model_path} --tag_set serve --signature_def serving_default"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Loaded SavedModel from '/tmp/tfhub_modules/426589ad685896ab7954855255a52db3442cb38d'\n",
            "Serving signatures: []\n",
            "\n",
            "---\n",
            "\n",
            "Checking for signature_defs using saved_model_cli:\n",
            "\n",
            "2023-04-26 17:12:32.367522: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
            "Traceback (most recent call last):\n",
            "  File \"/usr/local/bin/saved_model_cli\", line 8, in <module>\n",
            "    sys.exit(main())\n",
            "  File \"/usr/local/lib/python3.9/dist-packages/tensorflow/python/tools/saved_model_cli.py\", line 1284, in main\n",
            "    app.run(smcli_main)\n",
            "  File \"/usr/local/lib/python3.9/dist-packages/absl/app.py\", line 308, in run\n",
            "    _run_main(main, args)\n",
            "  File \"/usr/local/lib/python3.9/dist-packages/absl/app.py\", line 254, in _run_main\n",
            "    sys.exit(main(argv))\n",
            "  File \"/usr/local/lib/python3.9/dist-packages/tensorflow/python/tools/saved_model_cli.py\", line 1282, in smcli_main\n",
            "    args.func()\n",
            "  File \"/usr/local/lib/python3.9/dist-packages/tensorflow/python/tools/saved_model_cli.py\", line 961, in show\n",
            "    _show_inputs_outputs(\n",
            "  File \"/usr/local/lib/python3.9/dist-packages/tensorflow/python/tools/saved_model_cli.py\", line 345, in _show_inputs_outputs\n",
            "    inputs_tensor_info = _get_inputs_tensor_info_from_meta_graph_def(\n",
            "  File \"/usr/local/lib/python3.9/dist-packages/tensorflow/python/tools/saved_model_cli.py\", line 306, in _get_inputs_tensor_info_from_meta_graph_def\n",
            "    raise ValueError(\n",
            "ValueError: Could not find signature \"serving_default\". Please choose from: __saved_model_init_op\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kKqqX2LsReNz"
      },
      "source": [
        "Since the model we downloaded did not include any serving signatures, we'll re-export it with serving signatures defined.\n",
        "\n",
        "* https://www.tensorflow.org/guide/saved_model#specifying_signatures_during_export"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "OlDG2OuqOBGC",
        "outputId": "c25d0e59-3a42-4f43-804c-c607eb9fc84c"
      },
      "source": [
        "#@title Look up input signatures to use when exporting\n",
        "\n",
        "# To save serving signatures we need to specify a `ConcreteFunction` with a\n",
        "# TensorSpec signature. We can determine what this signature should be by\n",
        "# looking at any documentation for the model or running the saved_model_cli.\n",
        "\n",
        "!saved_model_cli show --dir {model_path} --all \\\n",
        "    2> /dev/null | grep \"inputs: TensorSpec\" | tail -n 1"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "          inputs: TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='inputs')\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gnb4HhMmkgiT",
        "outputId": "d5ff3d4a-0483-476e-af6e-0b3c827d4938"
      },
      "source": [
        "#@title Re-export the model using the known signature\n",
        "\n",
        "# Get a concrete function using the signature we found above.\n",
        "# \n",
        "# The first element of the shape is a dynamic batch size. We'll be running\n",
        "# inference on a single image at a time, so set it to `1`. The rest of the\n",
        "# shape is the fixed image dimensions [width=224, height=224, channels=3].\n",
        "call = loaded_model.__call__.get_concrete_function(tf.TensorSpec([1, 224, 224, 3], tf.float32))\n",
        "\n",
        "# Save the model, setting the concrete function as a serving signature.\n",
        "# https://www.tensorflow.org/guide/saved_model#saving_a_custom_model\n",
        "resaved_model_path = '/tmp/resaved_model'\n",
        "tf.saved_model.save(loaded_model, resaved_model_path, signatures=call)\n",
        "clear_output()  # Skip over TensorFlow's output.\n",
        "print(f\"Saved model with serving signatures to '{resaved_model_path}'\")\n",
        "\n",
        "# Load the model back into memory and check that it has serving signatures now\n",
        "reloaded_model = tf.saved_model.load(resaved_model_path)\n",
        "reloaded_serving_signatures = list(reloaded_model.signatures.keys())\n",
        "print(f\"\\nReloaded SavedModel from '{resaved_model_path}'\")\n",
        "print(f\"Serving signatures: {reloaded_serving_signatures}\")\n",
        "\n",
        "# Also check with the saved_model_cli:\n",
        "print(\"\\n---\\n\")\n",
        "print(\"Checking for signature_defs using saved_model_cli:\\n\")\n",
        "!saved_model_cli show --dir {resaved_model_path} --tag_set serve --signature_def serving_default"
      ],
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Saved model with serving signatures to '/tmp/resaved_model'\n",
            "\n",
            "Reloaded SavedModel from '/tmp/resaved_model'\n",
            "Serving signatures: ['serving_default']\n",
            "\n",
            "---\n",
            "\n",
            "Checking for signature_defs using saved_model_cli:\n",
            "\n",
            "2023-04-26 17:13:06.873761: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
            "The given SavedModel SignatureDef contains the following input(s):\n",
            "  inputs['inputs'] tensor_info:\n",
            "      dtype: DT_FLOAT\n",
            "      shape: (1, 224, 224, 3)\n",
            "      name: serving_default_inputs:0\n",
            "The given SavedModel SignatureDef contains the following output(s):\n",
            "  outputs['output_0'] tensor_info:\n",
            "      dtype: DT_FLOAT\n",
            "      shape: (1, 1001)\n",
            "      name: StatefulPartitionedCall:0\n",
            "Method name is: tensorflow/serving/predict\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YdmgASzwanSz"
      },
      "source": [
        "### Import and compile the SavedModel with IREE"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "GLkjlHE5mdmg",
        "outputId": "c67419f8-94de-4335-ddbc-f062b7d2e48a"
      },
      "source": [
        "#@title Import from SavedModel\n",
        "\n",
        "# The main output file from compilation is a .vmfb \"VM FlatBuffer\". This file\n",
        "# can used to run the compiled model with IREE's runtime.\n",
        "output_file = os.path.join(ARTIFACTS_DIR, \"mobilenet_v2.vmfb\")\n",
        "# As compilation runs, dump an intermediate .mlir file for future inspection.\n",
        "iree_input = os.path.join(ARTIFACTS_DIR, \"mobilenet_v2_iree_input.mlir\")\n",
        "\n",
        "# Since our SavedModel uses signature defs, we use `saved_model_tags` with\n",
        "# `import_type=\"SIGNATURE_DEF\"`. If the SavedModel used an object graph, we\n",
        "# would use `exported_names` with `import_type=\"OBJECT_GRAPH\"` instead.\n",
        "\n",
        "# We'll set `target_backends=[\"vmvx\"]` to use IREE's reference CPU backend.\n",
        "# We could instead use different backends here, or set `import_only=True` then\n",
        "# download the imported .mlir file for compilation using native tools directly.\n",
        "\n",
        "tfc.compile_saved_model(\n",
        "    resaved_model_path,\n",
        "    output_file=output_file,\n",
        "    save_temp_iree_input=iree_input,\n",
        "    import_type=\"SIGNATURE_DEF\",\n",
        "    saved_model_tags=set([\"serve\"]),\n",
        "    target_backends=[\"vmvx\"])\n",
        "clear_output()  # Skip over TensorFlow's output.\n",
        "\n",
        "print(f\"Saved compiled output to '{output_file}'\")\n",
        "print(f\"Saved iree_input to      '{iree_input}'\")"
      ],
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Saved compiled output to '/tmp/iree/colab_artifacts/mobilenet_v2.vmfb'\n",
            "Saved iree_input to      '/tmp/iree/colab_artifacts/mobilenet_v2_iree_input.mlir'\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 86
        },
        "id": "IEJAzOb5qASI",
        "outputId": "9a29aa51-b99d-4acd-dae8-0d97cf9786e6"
      },
      "source": [
        "#@title Download compilation artifacts\n",
        "\n",
        "ARTIFACTS_ZIP = \"/tmp/mobilenet_colab_artifacts.zip\"\n",
        "\n",
        "print(f\"Zipping '{ARTIFACTS_DIR}' to '{ARTIFACTS_ZIP}' for download...\")\n",
        "!cd {ARTIFACTS_DIR} && zip -r {ARTIFACTS_ZIP} .\n",
        "\n",
        "# Note: you can also download files using the file explorer on the left\n",
        "try:\n",
        "    from google.colab import files\n",
        "    print(\"Downloading the artifacts zip file...\")\n",
        "    files.download(ARTIFACTS_ZIP)\n",
        "except ImportError:\n",
        "    print(\"Missing google_colab Python package, can't download files\")"
      ],
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Zipping '/tmp/iree/colab_artifacts' to '/tmp/mobilenet_colab_artifacts.zip' for download...\n",
            "  adding: mobilenet_v2.vmfb (deflated 8%)\n",
            "  adding: mobilenet_v2_iree_input.mlir (deflated 46%)\n",
            "Downloading the artifacts zip file...\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ],
            "application/javascript": [
              "\n",
              "    async function download(id, filename, size) {\n",
              "      if (!google.colab.kernel.accessAllowed) {\n",
              "        return;\n",
              "      }\n",
              "      const div = document.createElement('div');\n",
              "      const label = document.createElement('label');\n",
              "      label.textContent = `Downloading \"${filename}\": `;\n",
              "      div.appendChild(label);\n",
              "      const progress = document.createElement('progress');\n",
              "      progress.max = size;\n",
              "      div.appendChild(progress);\n",
              "      document.body.appendChild(div);\n",
              "\n",
              "      const buffers = [];\n",
              "      let downloaded = 0;\n",
              "\n",
              "      const channel = await google.colab.kernel.comms.open(id);\n",
              "      // Send a message to notify the kernel that we're ready.\n",
              "      channel.send({})\n",
              "\n",
              "      for await (const message of channel.messages) {\n",
              "        // Send a message to notify the kernel that we're ready.\n",
              "        channel.send({})\n",
              "        if (message.buffers) {\n",
              "          for (const buffer of message.buffers) {\n",
              "            buffers.push(buffer);\n",
              "            downloaded += buffer.byteLength;\n",
              "            progress.value = downloaded;\n",
              "          }\n",
              "        }\n",
              "      }\n",
              "      const blob = new Blob(buffers, {type: 'application/binary'});\n",
              "      const a = document.createElement('a');\n",
              "      a.href = window.URL.createObjectURL(blob);\n",
              "      a.download = filename;\n",
              "      div.appendChild(a);\n",
              "      a.click();\n",
              "      div.remove();\n",
              "    }\n",
              "  "
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ],
            "application/javascript": [
              "download(\"download_18545900-47df-4250-9a14-8453ca4b6fc2\", \"mobilenet_colab_artifacts.zip\", 41434352)"
            ]
          },
          "metadata": {}
        }
      ]
    }
  ]
}
